/*
  Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

  This file is part of Sunrise.

  Sunrise is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3 of the License, or
  (at your option) any later version.

  Sunrise is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.
  
*/

/** \file
    Defines the skeleton for running an analytic setup. The user
    includes this file and defines the functions that set up the problem
    and analyze the results. 

    The functions to define are description(), add_custom_options(),
    print_custom_options(), setup() and pre_shoot(). setup() creates the
    factory, emission, emergence, and dust_model objects and then calls
    run_case(), which does the calculation. pre_shoot() is called on the
    transfer object just before the primary radiation is shot when the
    "standard" RT method is selected.

*/

#ifndef __analytic_case__
#define __analytic_case__

#include <algorithm>

#include "xfer.h"
#include "mlt_xfer.h"
#include "random.h"
#include "chromatic_policy.h"
#include "dust_model.h"
#include "full_sed_grid.h"
#include "ir_grid.h"
#include "equilibrium.h"
#include "emission.h"
#include "grid_factory.h"
#include "grain_model.h"
#include "shoot.h"
#include "imops.h"
#include "emergence-fits.h"
#include "grain_model_hack.h"
#include "boost/program_options.hpp"
#include "CCfits/CCfits"
#include "misc.h"

using namespace mcrx;
using blitz::Range;
namespace po = boost::program_options;

// Types used throughout this file: the random-number policy, the
// scatterer and dust-model types built on it, the emergence (camera)
// type, the full-SED and IR adaptive grids, the emission type, and the
// standard transfer object combining the dust model with the grid.
typedef local_random T_rng_policy;
typedef scatterer<polychromatic_scatterer_policy, T_rng_policy> T_scatterer;
typedef T_scatterer::T_lambda T_lambda;
typedef T_scatterer::T_biaser T_biaser;
typedef dust_model<T_scatterer, cumulative_sampling, T_rng_policy> T_dust_model;
typedef full_sed_emergence T_emergence;

typedef T_full_sed_adaptive_grid T_grid;
typedef T_ir_adaptive_grid T_ir_grid;
typedef emission<polychromatic_policy, T_rng_policy> T_emission;
typedef xfer<T_dust_model, T_grid> T_xfer;

// Prevent instantiation of these templates in every file that includes
// this header; explicit instantiations already exist in libmcrx. This
// lowers the -O0 compilation time from about 6 min to 2 min.
extern template class mcrx::xfer<T_dust_model, T_grid>;
extern template class mcrx::full_sed_grid<typename T_grid::T_grid_impl>;


// Hooks that the file including this header must define.
// description():          text used in the option help and parameter dump.
// add_custom_options():   adds case-specific command-line options.
// print_custom_options(): prints the values of those options.
// setup():                builds the problem and (normally) calls run_case().
// pre_shoot():            called on the standard transfer object right
//                         before the primary radiation is shot.
std::string description();
void add_custom_options(po::options_description&);
void print_custom_options(const po::variables_map&);
void setup(po::variables_map&);
void pre_shoot(T_xfer& x);

/** Builds a new emergence object whose camera list consists of `copies`
    entries, each referring to camera number `c` of `em`. How the
    cameras are stored (copied or referenced) is up to the T_emergence
    constructor that takes a vector of camera pointers. A non-positive
    `copies` yields an emergence built from an empty camera list. */
T_emergence extract_camera(const T_emergence& em, int c, int copies)
{
  std::vector<const T_emergence::T_camera*> camera_list;
  if (copies > 0)
    camera_list.assign(copies, &em.get_camera(c));
  return T_emergence(camera_list);
}

/** Shoots the primary radiation using Metropolis light transport (MLT).

    mlt_xfer only handles a single camera, so the calculation is repeated
    for each camera in `cameras`. For each one, a temporary emergence
    holding n_cams copies of it is built, the Metropolis driver is run,
    and the normalized image of the first copy is stored back into the
    corresponding camera of `cameras`.

    \param opt        Parsed options. Reads n_mut, n_metro, n_seeds,
                      nscat_min, nscat_max and reflambda.
    \param grid       Grid to propagate through.
    \param emission   Source emission. Its zero_lambda() is used to
                      allocate the camera images.
    \param cameras    Cameras. They are (re)allocated here and receive the
                      final, normalized images.
    \param model      Dust model.
    \param lambda     Wavelength vector, used to locate the reference
                      wavelength index for the biaser.
    \param rng_states Random-number states. Only the first is used, so the
                      generator matches what shoot() would use.
*/
void shoot_mlt(po::variables_map opt, 
	       T_grid& grid,
	       T_emission& emission,
	       T_emergence& cameras,
	       T_dust_model& model,
	       const T_lambda& lambda,
	       std::vector<T_rng::T_state> rng_states)
{
  const long n_mutations =opt["n_mut"].as<long>();
  const long n_metropolis =opt["n_metro"].as<long>();
  const long n_seeds =opt["n_seeds"].as<long>();
  const int minscat = opt["nscat_min"].as<int>();
  const int maxscat = opt["nscat_max"].as<int>();

  // Index of the last grid wavelength strictly below "reflambda". It is
  // clamped at 0 because a reference wavelength at or below the first
  // grid point would otherwise give the invalid index -1.
  std::vector<T_float> lambdas(lambda.begin(), lambda.end());
  const int reflambda = std::max
    (0, static_cast<int>(std::lower_bound(lambdas.begin(), lambdas.end(),
					  opt["reflambda"].as<T_float>())
			 - lambdas.begin()) - 1);
  
  // Number of copies of each camera handed to mlt_xfer. This matches the
  // per-camera count shoot_bidir() uses for max_n_track = 4.
  const int max_n_track= 4;
  const int n_cams=(max_n_track+1)*(max_n_track+2)/2+1;

  for (typename T_emergence::iterator c = cameras.begin(); c != cameras.end(); ++c) {
    (*c)->allocate(emission.zero_lambda());
  }

  // since this only works with one camera, we need to repeat
  // calculation for each camera
  for(int c=0; c<cameras.n_cameras(); ++c) {
    T_emergence this_cam= extract_camera(cameras,c, n_cams);

    T_rng_policy rng;
    mlt_xfer<T_dust_model, T_grid> x (grid, emission, this_cam, model, rng,
				      polychromatic_biaser(reflambda),
				      minscat, maxscat);
    // ensure rngs are in sync with what we get in shoot()
    x.rng().make_unique(0);
    x.rng().get_generator().setState(rng_states[0]);
    const T_float mean_weight = 
      x.metropolis_driver(n_seeds, n_metropolis, n_mutations);
    const long n_rays_actual = n_metropolis*n_mutations;

    // here we have to manually apply the normalization factor
    cameras.get_camera(c).get_image() = 
      this_cam.get_camera(0).get_image()*(mean_weight/n_rays_actual);
  }
}

/** Shoots the primary radiation using bidirectional path tracing.

    The transfer object handles one camera at a time, so the calculation
    is repeated for each camera in `cameras`.

    \param opt        Parsed options. Reads n_rays, ns_max, nc_max,
                      nscat_min, nscat_max, n_threads and reflambda.
    \param grid       Grid to propagate through.
    \param emission   Source emission. Its zero_lambda() is used to
                      allocate the result images.
    \param cameras    Cameras to render. Their images are not written.
    \param model      Dust model.
    \param lambda     Wavelength vector, used to locate the reference
                      wavelength index for the biaser.
    \param rng_states Random-number states passed on to shoot().
    \return A new emergence object holding the normalized images. If
            `cameras` holds a single camera, the result holds n_cams
            copies of it so the results can be disaggregated by subpath
            lengths. Otherwise it is a copy of `cameras` and no
            disaggregation is done.
*/
boost::shared_ptr<T_emergence>
shoot_bidir(po::variables_map opt, 
	    T_grid& grid,
	    T_emission& emission,
	    T_emergence& cameras,
	    T_dust_model& model,
	    const T_lambda& lambda,
	    std::vector<T_rng::T_state> rng_states)
{
  const long n_rays =opt["n_rays"].as<long>();
  const int ns_max = opt["ns_max"].as<int>();
  const int nc_max = opt["nc_max"].as<int>();
  const int minscat = opt["nscat_min"].as<int>();
  const int maxscat = opt["nscat_max"].as<int>();
  const int n_threads = opt["n_threads"].as<int>();

  // Index of the last grid wavelength strictly below "reflambda". It is
  // clamped at 0 because a reference wavelength at or below the first
  // grid point would otherwise give the invalid index -1.
  std::vector<T_float> lambdas(lambda.begin(), lambda.end());
  const int reflambda = std::max
    (0, static_cast<int>(std::lower_bound(lambdas.begin(), lambdas.end(),
					  opt["reflambda"].as<T_float>())
			 - lambdas.begin()) - 1);

  // if there's more than one camera in the emergence object, we
  // obviously can't disaggregate the results by subpath lengths
  const int max_n_track= (cameras.n_cameras()>1) ? -1 : 4;
  const int n_cams= 
    (cameras.n_cameras()>1) ? 1 : (max_n_track+1)*(max_n_track+2)/2+1;

  boost::shared_ptr<T_emergence> save_em;
  if(cameras.n_cameras()>1)
    // just copy the emergence
    save_em.reset(new T_emergence(cameras));
  else
    save_em.reset(new T_emergence(extract_camera(cameras,0, n_cams)));

  for (typename T_emergence::iterator c = save_em->begin(); c != save_em->end(); ++c) {
    (*c)->allocate(emission.zero_lambda());
  }

  dummy_terminator t;

  // since this only works with one camera, we need to repeat
  // calculation for each camera
  for(int c=0; c<cameras.n_cameras(); ++c) {
    T_emergence this_cam= extract_camera(cameras,c, n_cams);
    T_rng_policy rng;
    mlt_xfer<T_dust_model, T_grid> x (grid, emission, this_cam, model, rng,
				      polychromatic_biaser(reflambda),
				      minscat, maxscat);
    x.max_n_track_ = max_n_track;
    long n_rays_actual = 0;
    shoot (x, bidir_shooter(ns_max, nc_max), t, rng_states, 
	   n_rays, n_rays_actual, n_threads, true, n_threads, 0, "");

    // here we have to manually apply the normalization factor
    for(int cc=0; cc<n_cams; ++cc)
      save_em->get_camera(c+cc).get_image() = 
	this_cam.get_camera(cc).get_image()*(1./n_rays_actual);
  }

  return save_em;
}


/** Builds the grid and runs the full radiative-transfer calculation for
    an analytic case, writing the results to the output file.

    1. Builds the adaptive grid from `factory` and wraps it in a
       full_sed_grid, then prints the total dust mass.
    2. Shoots the primary radiation with the method chosen by the
       "rt_method" option ("standard", "MLT" or "bidir"). Writes the
       camera parameters, the "-STAR" images and their integrated SEDs
       to a newly created (overwritten) output file.
    3. Unless "do_ir" is false:
       - determines the dust equilibrium intensities;
       - computes the dust emission images by direct integration
         ("integrate_ir") and/or by shooting "n_rays_ir" rays;
       - appends the "-IR" images, their integrated SEDs, the cell SEDs
         and (if "save_temps") the dust temperatures.

    \param opt      Parsed command-line options (see add_common_options()).
    \param factory  Factory used to build the adaptive grid.
    \param units    Unit map written to the output file.
    \param lambda   Wavelength vector.
    \param gridmin  Lower corner of the grid.
    \param gridmax  Upper corner of the grid.
    \param emission Primary emission object.
    \param cameras  Cameras. Also reused (zeroed and reallocated) for the
                    IR stage.
    \param model    Dust model. Each scatterer must be a grain_model for
                    the IR stage.

    Throws 1 (after printing a message) on an unknown "rt_method", or if
    a scatterer is not a grain_model when dust emission is calculated.
*/
template<typename T_factory>
void run_case(po::variables_map opt, 
	      T_factory& factory,
	      const T_unit_map& units,
	      const T_lambda& lambda,
	      vec3d gridmin,
	      vec3d gridmax,
	      T_emission& emission,
	      T_emergence& cameras,
	      T_dust_model& model)
{
  const int n_threads = opt["n_threads"].as<int>();
  const long n_rays =opt["n_rays"].as<long>();
  const long n_rays_ir =opt["n_rays_ir"].as<long>();
  const long n_rays_temp = opt["n_rays_temp"].as<long>();
  const T_float i_max=opt["i_max"].as<T_float>();
  const T_float i_min=opt["i_min"].as<T_float>();
  const T_float ir_tol =opt["ir_tol"].as<T_float>();
  const T_float ir_lumfrac =opt["ir_lumfrac"].as<T_float>();
  // Output file name after word_expand(). Both the -STAR stage and the
  // IR stage must open this same name, otherwise they can end up
  // writing to different files.
  const string output_file(word_expand(opt["output-file"].as<string>())[0]);
  // the scattering limits are integer options
  const int nscat_min=opt["nscat_min"].as<int>();
  const int nscat_max=opt["nscat_max"].as<int>();
  const bool save_temps = opt["save_temps"].as<bool>();
  const bool integrate_ir = opt["integrate_ir"].as<bool>();

  mcrx::seed(opt["seed"].as<int>());

  // create the adaptive grid
  boost::shared_ptr<adaptive_grid<T_grid::T_grid_impl::T_data> > gg
    (new adaptive_grid<T_grid::T_grid_impl::T_data>  
     (gridmin, gridmax, factory));

  // and then the full_sed_grid
  T_grid gr(gg, units);

  // calculate dust mass: one row per cell, one column per scatterer
  array_2 dust_mass(gr.n_cells(), model.n_scatterers());
  int cc=0;
  for(T_grid::const_iterator c=gr.begin(); c!=gr.end(); ++c, ++cc) {
    dust_mass(cc, Range::all()) = 
      c->data()->get_absorber().densities()*c.volume();
  }
  cout << "Total dust mass: " << sum(dust_mass) << endl;
  
  cout << "shooting" << endl;

  std::vector<T_rng::T_state> states= generate_random_states (n_threads);
  T_rng_policy rng;
  T_biaser b;
  dummy_terminator t;

  // Select the RT method for the primary radiation. save_cameras points
  // at the emergence whose images are written as the -STAR output; the
  // bidir method returns its own emergence object.
  const string rt_method = opt["rt_method"].as<string>();
  T_emergence* save_cameras=&cameras;
  boost::shared_ptr<T_emergence> bidir_emergence;
  if(rt_method=="standard") {
    T_xfer x (gr, emission, cameras, model, rng, b,
	      i_min, i_max, false, false, nscat_min, nscat_max);
    
    pre_shoot(x);
    long n_rays_actual = 0;
    shoot (x, scatter_shooter(), t, states, n_rays,
	   n_rays_actual, n_threads, true, n_threads, 0, "");
  }
  else if(rt_method=="MLT") {
    shoot_mlt(opt, gr, emission, cameras, model, lambda, states);
  }
  else if(rt_method=="bidir") {
    bidir_emergence = 
      shoot_bidir(opt, gr, emission, cameras, model, lambda, states);
    save_cameras=bidir_emergence.get();
  }
  else {
    std::cerr<< "Unknown RT method: " << rt_method << std::endl;
    throw 1;
  }
  
  {
    // "!" prefix: create the file, overwriting any existing one
    CCfits::FITS output("!"+output_file, CCfits::Write);
    save_cameras->write_parameters(output, units);

    CCfits::Table* iq_hdu=output.addTable("INTEGRATED_QUANTITIES",0);
    // NOTE(review): the wavelength column is tagged with the "L_lambda"
    // unit; confirm this is intended (a wavelength unit looks expected).
    iq_hdu->addColumn(CCfits::Tdouble, "lambda", 1,
		      units.get("L_lambda"));
    write(iq_hdu->column("lambda"),lambda,1);

    save_cameras->write_images(output, 1.0, units, "-STAR", false, false);

    // calculate integrated SEDs
    for (int i=0; i<save_cameras->n_cameras(); ++i) {
      integrate_image(output,
		      "CAMERA"+boost::lexical_cast<string>(i)+"-STAR",
		      "CAMERA"+boost::lexical_cast<string>(i)+"-PARAMETERS",
		      "L_lambda"+boost::lexical_cast<string>(i),
		      lambda,
		      "",
		      units,
		      dummy_terminator());
    }
  }

  if(!opt["do_ir"].as<bool>())
    return;

  // create ir grid
  T_ir_grid irg(gg, dust_mass, lambda, units);

  array_2 dust_intensity(gr.n_cells(), lambda.size());
  dust_intensity=0;

  // Determine dust equilibrium
  determine_dust_equilibrium_intensities
    (model,
     gr,
     irg,
     emission,
     lambda,
     b,
     dust_intensity,
     states,
     n_rays_temp,
     n_threads,
     true,
     n_threads,
     ir_tol,
     ir_lumfrac,
     i_min,
     i_max,
     t,
     false);

  // make vector of grain models from the scatterers so we can calculate
  // emission. A scatterer that is not a grain_model cannot be used here,
  // so report it instead of dereferencing a null pointer.
  typedef grain_model<polychromatic_scatterer_policy, 
    mcrx_rng_policy> T_grain_model;
  std::vector<T_grain_model*> grain_models;
  for(int i=0; i< model.n_scatterers(); ++i) {
    T_grain_model* const gm =
      dynamic_cast<T_grain_model*>(&model.get_scatterer(i));
    if(!gm) {
      std::cerr << "Scatterer " << i
		<< " is not a grain model, can't calculate dust emission"
		<< std::endl;
      throw 1;
    }
    gm->resample(lambda);
    grain_models.push_back(gm);
  }

  // zero out cameras for ir run
  for(typename T_emergence::iterator c= cameras.begin();
      c != cameras.end(); ++c)
    (*c)->get_image() = 0;

  // calculate SED without sampling normalization so we get correct
  // emissivity for integration.
  irg.calculate_SED(dust_intensity, grain_models, t, n_threads, 0, 
		    false, false, save_temps);
  gr.reset();
  for (typename T_emergence::iterator c = cameras.begin(); c != cameras.end(); ++c) {
    (*c)->allocate(irg.zero_lambda());
  }

  if(integrate_ir) {
    T_xfer x2 (gr, irg, cameras, model, rng, b, 
	       i_min, i_max, false, false);
    
    cout << "Integrating source function\n";

    long dummy=0;
    x2.set_integrations_per_cell(opt["rays_per_cell"].as<int>());
    shoot(x2, integration_shooter(integration_shooter::intensity), t, states, 0, dummy, n_threads, true, n_threads, 0, "");

    // because the image values when integrating directly are already in
    // surface brightness, we need to back convert them to what's output
    // by the shooting so that they are consistent. To do this we must
    // multiply them by the normalized area of a pixel.
    for(typename T_emergence::iterator c= cameras.begin();
	c != cameras.end(); ++c)
      (*c)->get_image() *= (*c)->pixel_normalized_area();
    
  }

  // now shoot for the contribution from scattering (should be small)
  if (n_rays_ir>0) {
    irg.normalize_for_sampling();

    T_xfer x3 (gr, irg, cameras, model, rng, b,
	       i_min, i_max, false ,false,
	       integrate_ir ? 1:0);

    if(integrate_ir)
      cout << "Shooting for scattered intensities\n";
    else
      cout << "Shooting for total dust emission intensities\n";
    
    long n_rays_actual = 0;
    shoot (x3, scatter_shooter(), t, states, n_rays_ir,
	   n_rays_actual, n_threads, true, n_threads, 0, "");
  }

  // reopen the file written above and append the IR results
  CCfits::FITS output(output_file, CCfits::Write);
  cameras.write_images(output,1.0,units, "-IR", false, false);

  // calculate integrated SEDs
  for (int i=0; i<cameras.n_cameras(); ++i) {
    integrate_image(output,
		    "CAMERA"+boost::lexical_cast<string>(i)+"-IR",
		    "CAMERA"+boost::lexical_cast<string>(i)+"-PARAMETERS",
		    "L_lambda_ir"+boost::lexical_cast<string>(i),
		    lambda,
		    "",
		    units,
		    dummy_terminator());
  }

  irg.write_seds(output, "CELLSEDS", (n_rays_ir>0) );

  if(save_temps)
    irg.write_temps(output, "DUSTTEMPS");
    
}

/** Writes the value of every option declared by add_common_options(),
    except npix, grain_data_dir and help, to standard output. Each option
    is written as one "name = value" line, in a fixed order. Booleans
    are printed in the stream's current format. */
void print_common_options(const po::variables_map& opt)
{
  std::ostream& out = std::cout;

  // One printer per option value type; each writes "key = value\n".
  auto put_text = [&out, &opt](const char* key) {
    out << key << " = " << opt[key].as<string>() << '\n';
  };
  auto put_long = [&out, &opt](const char* key) {
    out << key << " = " << opt[key].as<long>() << '\n';
  };
  auto put_int = [&out, &opt](const char* key) {
    out << key << " = " << opt[key].as<int>() << '\n';
  };
  auto put_flag = [&out, &opt](const char* key) {
    out << key << " = " << opt[key].as<bool>() << '\n';
  };
  auto put_real = [&out, &opt](const char* key) {
    out << key << " = " << opt[key].as<T_float>() << '\n';
  };

  put_text("output-file");
  put_long("n_rays");
  put_long("n_rays_temp");
  put_long("n_rays_ir");
  put_flag("integrate_ir");
  put_int("rays_per_cell");
  put_flag("save_temps");
  put_int("n_threads");
  put_int("maxlevel");
  put_real("i_max");
  put_real("i_min");
  put_flag("do_ir");
  put_real("ir_tol");
  put_real("ir_lumfrac");
  put_real("tau_tol");
  put_real("lambda_min");
  put_real("lambda_max");
  put_int("n_lambda");
  put_int("seed");
  put_text("wd01_parameter_set");
  put_int("nscat_min");
  put_int("nscat_max");
  put_text("rt_method");
  put_long("n_seeds");
  put_long("n_metro");
  put_long("n_mut");
  put_int("nc_max");
  put_int("ns_max");
  put_real("reflambda");
}


/** Adds the radiation-transfer options shared by all analytic cases to
    `desc`. These cover ray counts, IR settings, grid refinement,
    wavelength range, RT method parameters, the output file and help.

    \param desc Option description to extend. Case-specific options are
                added separately by add_custom_options().
*/
void add_common_options(po::options_description& desc)
{
  desc.add_options()
    ("n_rays", po::value<long>()->default_value(1000000), "number of rays")
    ("n_rays_temp", po::value<long>()->default_value(1000000), 
     "number of rays for dust temp iteration")
    ("n_rays_ir", po::value<long>()->default_value(1000000), 
     "number of rays for final dust emission")
    ("integrate_ir", po::value<bool>()->default_value(true),
     "directly integrate the RTE for IR SED? (n_rays_ir then used for scattered contribution)")
    ("rays_per_cell", po::value<int>()->default_value(2),
     "Number of integration rays per cell")
    ("save_temps", po::value<bool>()->default_value(false),
     "save dust temperatures for the cells")
    ("n_threads", po::value<int>()->default_value(12), "number of threads")
    ("maxlevel", po::value<int>()->default_value(6), 
     "max grid refinement level")
    ("i_max",  po::value<T_float>()->default_value(10),
     "max ray intensity")
    ("i_min", po::value<T_float>()->default_value(1e-2),
     "ray intensity where Russian Roulette starts")
    ("do_ir", po::value<bool>()->default_value(true), 
     "calculate dust emission")
    ("ir_tol", po::value<T_float>()->default_value(0.1), 
     "maximum variance of cell luminosities during equilibrium calculation")
    ("ir_lumfrac", po::value<T_float>()->default_value(0.01), 
     "ir luminosity percentile for including cells in variance check")
    ("tau_tol", po::value<T_float>()->default_value(100),
     "cell optical depth refinement tolerance")
    ("lambda_min", po::value<T_float>()->default_value(1e-7),
     "lower end of wavelength range")
    ("lambda_max", po::value<T_float>()->default_value(1e-3),
     "upper end of wavelength range")
    ("n_lambda", po::value<int>()->default_value(50),
     "number of wavelengths")
    ("seed", po::value<int>()->default_value(42), "random number seed")
    ("npix", po::value<int>()->default_value(50), "size of output image")
    ("wd01_parameter_set", po::value<string>()->default_value("DL07_MW3.1_60"), "The parameter set for WD01 grain models")
    ("grain_data_dir", po::value<string>()->default_value("~/dust_data/crosssections"), "grain data directory")
    ("nscat_min", po::value<int>()->default_value(0), "min number of scatterings to register")
    ("nscat_max", po::value<int>()->default_value(blitz::huge(int())), "max number of scatterings to register")
    ("rt_method", po::value<string>()->default_value("standard"), "RT method used for primary radiation (other options are \"MLT\" and \"bidir\")")
    ("n_seeds", po::value<long>()->default_value(10000), "number of Metropolis seed paths")
    ("n_metro", po::value<long>()->default_value(100), "number of Metropolis runs")
    ("n_mut", po::value<long>()->default_value(1000000), "number of mutations per metropolis run")
    ("ns_max", po::value<int>()->default_value(blitz::huge(int())), "max number of source vertices when doing bidirectional tracing")
    ("nc_max", po::value<int>()->default_value(blitz::huge(int())), "max number of camera vertices when doing bidirectional tracing")
    // reflambda is read by both shoot_mlt() and shoot_bidir()
    ("reflambda", po::value<T_float>()->default_value(550e-9), "reference wavelength (only used for MLT and bidirectional tracing)")
    ("output-file", po::value<string>(), "output file name")
    ("help", "print this page")
    ;
}

int main (int argc, char** argv)
{
  mcrx::mpienv::init(argc, argv);
  hpm::disable();

  // Declare the supported options.
  po::options_description desc(description()+" Allowed options:");

  add_common_options(desc);
  add_custom_options(desc);
  po::positional_options_description p;
  p.add("output-file", 1);

  po::variables_map opt;
  po::store(po::command_line_parser(argc, argv).
	    options(desc).positional(p).run(), opt);
  po::notify(opt);    

  if (opt.count("help")) {
    cout << desc << "\n";
    return 1;
  }
  if(!opt.count("output-file")) {
    cout << "Must specify output file name on command line" << endl;
    return 1;
  }

  if(opt["rt_method"].as<string>()=="bidir" && 
     ( (opt["ns_max"].as<int>()<1) || 
       (opt["nc_max"].as<int>()<1))) {
    std::cerr << "Error: Both ns_max and nc_max must be at least one.\n";
    exit(1);
  }

  std::cout << description() << " Parameters:\n";
  print_common_options(opt);
  print_custom_options(opt);
  std::cout << std::endl;

  setup(opt);
}

#endif
